#include "POF/Gradient_generator_fullcal_sample_grad.h"

#include <assert.h>
#include <math.h>

#include <cmath>
#include <stdexcept>
#include <thread>
#include <vector>

using namespace std;

// Gradient buffer defined in another translation unit. In this file it is
// indexed as ccon_gradient[label pm][class i][sample j], with
// pm < n_label, i < nn_data->n_class, j < nn_data->batch_size.
extern double ccon_gradient[n_label][MAX_n_class][MAX_batch_size];

// Defined elsewhere. Called after ccon_gradient has been filled, presumably to
// copy it into the flat ccon_gradient_list (TODO: confirm the flattening order
// against the definition). Note the argument order: batch_size comes BEFORE
// n_class, unlike the [label][class][sample] layout of ccon_gradient.
extern void flatten_ccon_gradient(int n_label, int batch_size, int n_class, double ccon_gradient_list[]);

void fullcalc_grad_ccon_sample_grad(struct NN_data* nn_data, struct User_parameter* user_parameter, struct Indata* indata, double ccon_gradient_list[], double clip) {

    for (int pm = 0; pm < n_label; pm++) {
        for (int j = 0; j < nn_data->batch_size; j++) {
            double approx_loss_class[nn_data->n_class];
            double expsum = 0;

            for (int i = 0; i < nn_data->n_class; i++) {
                int tmp = nn_data->normalclass[pm][j]; // KKY: check
                nn_data->normalclass[pm][j] = i;
                // KKY: assumption:normaltable, posterior are already set.
                approx_loss_class[i] = calc_approx_loss(nn_data, indata, user_parameter);
                // cout<<approx_loss_class[i]<<endl;
                expsum += exp(nn_data->ccon_w[pm][i][j]);
                nn_data->normalclass[pm][j] = tmp;
            }
            for (int i = 0; i < nn_data->n_class; i++) {
                ccon_gradient[pm][i][j] = 0;
                for (int k = 0; k < nn_data->n_class; k++) {
                    if (k == i)
                        ccon_gradient[pm][i][j] += approx_loss_class[k] * (exp(nn_data->ccon_w[pm][k][j]) / expsum);
                    ccon_gradient[pm][i][j] += -approx_loss_class[k] * (exp(nn_data->ccon_w[pm][k][j]) * exp(nn_data->ccon_w[pm][i][j]) / expsum / expsum);
                }
                // NO_UPDATE
                if (ccon_gradient[pm][i][j] > clip)ccon_gradient[pm][i][j] = clip;
                if (ccon_gradient[pm][i][j] < -clip)ccon_gradient[pm][i][j] = -clip;
            }
        }
    }

    flatten_ccon_gradient(n_label, nn_data->batch_size, nn_data->n_class, ccon_gradient_list);
}